--- Implementation of the n-body problem shootout using a hash map
module examples.LBodies 
        inline (sqr)  -- , Body.x, Body.y, Body.z, Body.vx, Body.vz, Body.vy, Body.mass)
    where

import frege.Prelude hiding(System)
import frege.prelude.Math (pi, sqrt)
import Data.HashMap as Map


--- Square of a 'Double': the argument multiplied by itself.
sqr :: Double → Double
sqr v = v * v

{--
    A body of the simulation, described by its position, its velocity
    and its mass.

    Every field is strict. The order of the fields is significant,
    because the constructor is also applied positionally.
-}
data Body = Body
    { !x    :: Double     -- position coordinates
    , !y    :: Double
    , !z    :: Double
    , !vx   :: Double     -- velocity components
    , !vy   :: Double
    , !vz   :: Double
    , !mass :: Double
    }

derive Show Body

--- The simulated bodies as a hash map from an 'Int' key to a 'Body'; @main@ uses the list positions 0, 1, ... as keys.
type System = HashMap Int Body

--- Mass assigned to the sun (4 * pi * pi); the masses of the other bodies are given as multiples of it.
solar_mass :: Double
solar_mass = 4.0 * pi * pi

--- Factor by which the velocity constants of the starting state are multiplied.
days_per_year :: Double
days_per_year = 365.24

{--
    Returns a copy of the given body whose velocity is set to the negated
    momentum (px, py, pz) divided by solar_mass.

    Position and mass are taken over unchanged; the previous velocity of
    the body is not used.
-}
offsetMomentum (body::Body) px py pz = Body {
    vx = recoil px, vy = recoil py, vz = recoil pz,
    x = body.x, y = body.y, z = body.z, mass = body.mass }
  where
    -- velocity component that cancels the momentum component p
    recoil p = negate p / solar_mass

{--
    The bodies at the start of the simulation, in the order
    sun, Jupiter, Saturn, Uranus, Neptune.

    The sun rests at the origin with mass solar_mass. For the other bodies
    the table lists position, velocity and mass; velocities are multiplied
    by days_per_year and masses by solar_mass.
-}
startingState = [sun, jupiter, saturn, uranus, neptune] where
    sun = Body 0.0 0.0 0.0 0.0 0.0 0.0 solar_mass

    -- builds a body from a table row, applying the velocity and mass factors
    planet px py pz dvx dvy dvz relMass = Body px py pz
        (dvx * days_per_year)
        (dvy * days_per_year)
        (dvz * days_per_year)
        (relMass * solar_mass)

    jupiter = planet
        4.84143144246472090e+00
        (negate 1.16032004402742839e+00)
        (negate 1.03622044471123109e-01)
        1.66007664274403694e-03
        7.69901118419740425e-03
        (negate 6.90460016972063023e-05)
        9.54791938424326609e-04

    saturn = planet
        8.34336671824457987e+00
        4.12479856412430479e+00
        (negate 4.03523417114321381e-01)
        (negate 2.76742510726862411e-03)
        4.99852801234917238e-03
        2.30417297573763929e-05
        2.85885980666130812e-04

    uranus = planet
        1.28943695621391310e+01
        (negate 1.51111514016986312e+01)
        (negate 2.23307578892655734e-01)
        2.96460137564761618e-03
        2.37847173959480950e-03
        (negate 2.96589568540237556e-05)
        4.36624404335156298e-05

    neptune = planet
        1.53796971148509165e+01
        (negate 2.59193146099879641e+01)
        1.79258772950371181e-01
        2.68067772490389322e-03
        1.62824170038242295e-03
        (negate 9.51592254519715870e-05)
        5.15138902046611451e-05

{--
    The starting state with its first body (the sun) replaced by the result
    of offsetMomentum applied to the summed momentum of all bodies.

    The momentum components are accumulated in list order, starting from 0.0.
-}
initSystem = case fold addMomentum (0.0, 0.0, 0.0) startingState of
        (px, py, pz) -> offsetMomentum (head startingState) px py pz : tail startingState
    where
        -- adds velocity * mass of one body to the running momentum sums
        addMomentum (sx, sy, sz) (b::Body) =
            (sx + b.vx * b.mass, sy + b.vy * b.mass, sz + b.vz * b.mass)

{- Java version of the energy computation, for reference:
    public double energy(){
        double dx, dy, dz, distance;
        double e = 0.0;

        for (int i=0; i < bodies.length; ++i) {
            e += 0.5 * bodies[i].mass *
               ( bodies[i].vx * bodies[i].vx
               + bodies[i].vy * bodies[i].vy
               + bodies[i].vz * bodies[i].vz );

            for (int j=i+1; j < bodies.length; ++j) {
                dx = bodies[i].x - bodies[j].x;
                dy = bodies[i].y - bodies[j].y;
                dz = bodies[i].z - bodies[j].z;

                distance = Math.sqrt(dx*dx + dy*dy + dz*dz);
                e -= (bodies[i].mass * bodies[j].mass) / distance;
            }
        }
        return e;
    }
-}

{--
    Energy of a list of bodies.

    For every body the term 0.5 * mass * (vx² + vy² + vz²) is added, and for
    every pair of a body with a body later in the list the term
    mass1 * mass2 / distance is subtracted. The empty list yields 0.0.
-}
energy :: [Body] → Double
energy bodies = total 0.0 bodies
    where
        total acc []            = acc
        total acc (body:others) = total (acc + kinetic - potential) others
            where
                kinetic   = 0.5 * body.mass * (sqr body.vx + sqr body.vy + sqr body.vz)
                -- pair terms of this body with all bodies that follow it
                potential = sum [ pairEnergy body other | other <- others ]
                pairEnergy (p::Body) (q::Body) = p.mass * q.mass / distance p q

--- Euclidean distance between the positions of two bodies.
distance :: Body → Body → Double
distance p q = sqrt (sqr (p.x - q.x) + sqr (p.y - q.y) + sqr (p.z - q.z))


--- Shifts the position of a body by 0.01 times its velocity; velocity and mass stay as they are.
move :: Body -> Body
move b = Body {
        x = drift b.x b.vx, y = drift b.y b.vy, z = drift b.z b.vz,
        vx = b.vx, vy = b.vy, vz = b.vz, mass = b.mass }
    where
        -- coordinate after adding 0.01 times the velocity component
        drift pos vel = pos + (0.01 * vel)

{--
    Performs one simulation step: the system is passed through every
    function of the work list, in list order, and afterwards move is
    applied to each body.
-}
advance :: System -> [System -> System] -> System
advance !system work = fmap move updated
    where
        updated = fold applyTo system work
        -- feeds the system built so far into the next update function
        applyTo s f = f s

{--
    Updates the velocities of the two bodies stored under the keys i and j
    and returns the system with both entries replaced.

    With d being the distance between the two bodies and mag = 0.01 / d³,
    the body at i loses (position difference * mass of j * mag) from its
    velocity, and the body at j gains (position difference * mass of i * mag).
    Positions and masses are not changed. A key that is missing from the
    system yields undefined.
-}
inner ∷ Int → Int → System → System
inner i j system = insert i bi' (insert j bj' system)
    where
        !bi   = lookupDefault undefined i system
        !bj   = lookupDefault undefined j system
        -- position difference, pointing from the body at j to the body at i
        !dx   = bi.x - bj.x
        !dy   = bi.y - bj.y
        !dz   = bi.z - bj.z
        !dist = sqrt (sqr dx + sqr dy + sqr dz)
        !mag  = 0.01 / (dist * dist * dist)
        bi' = Body {x = bi.x, y = bi.y, z = bi.z, mass = bi.mass,
                    vx = bi.vx - (dx * bj.mass * mag),
                    vy = bi.vy - (dy * bj.mass * mag),
                    vz = bi.vz - (dz * bj.mass * mag)}
        bj' = Body {x = bj.x, y = bj.y, z = bj.z, mass = bj.mass,
                    vx = bj.vx + (dx * bi.mass * mag),
                    vy = bj.vy + (dy * bi.mass * mag),
                    vz = bj.vz + (dz * bi.mass * mag)}

{--
    Runs n simulation steps on the given system and returns the final system.

    The list of pairwise velocity updates (inner i j for all keys i < j in
    0 .. size-1) is built once and reused for every call of advance.

    A step count of zero or less returns the system unchanged. The check is
    deliberately "<= 0" rather than "== 0": with an equality test a negative
    count would never reach the end condition and keep counting down.
-}
loop !n !system = go n system
    where
        !len  = size system
        -- one update function per unordered pair of keys
        !work = [ inner i j | i ← [0..len-2], j ← [i+1..len-1] ]
        go !k !s
            | k <= 0    = s
            | otherwise = go (k-1) (advance s work)

{--
    Entry point. The first command line argument is the number of steps.

    Prints the energy of the initial system, runs the simulation for the
    given number of steps and prints the energy of the resulting system.
    If there is no argument, or it is not an integer, a usage message is
    written to stderr instead.
-}
main (arg:_)
    | Right !steps <- arg.int = do
            let initial = initSystem
            stdout.printf "%.9f\n" (energy initial)
            -- bodies are keyed by their list position: 0, 1, 2, ...
            let endSystem = loop steps (Map.fromList (zip [0..] initial))
            -- read the bodies back in key order for the energy computation
            let ordered = [ lookupDefault undefined k endSystem | k <- [0..size endSystem-1] ]
            stdout.printf "%.9f\n" (energy ordered)
main _ = stderr.println "usage: java examples.LBodies steps"
